#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Dict, Iterable, Optional, Type, Union

import torch
from torch import no_grad
from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.precision import compute_precision, PrecisionMetric
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
    metric_test_helper,
    rec_metric_gpu_sync_test_launcher,
    rec_metric_value_test_launcher,
    RecTaskInfo,
    sync_test_helper,
    TestMetric,
)


WORLD_SIZE = 4


class TestPrecisionMetric(TestMetric):
    @staticmethod
    def _get_states(
        labels: torch.Tensor,
        predictions: torch.Tensor,
        weights: torch.Tensor,
        required_inputs_tensor: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        predictions = predictions.double()
        true_pos_sum = torch.sum(weights * ((predictions >= 0.5) * labels))
        false_pos_sum = torch.sum(weights * ((predictions >= 0.5) * (1 - labels)))
        return {
            "true_pos_sum": true_pos_sum,
            "false_pos_sum": false_pos_sum,
        }

    @staticmethod
    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
        return compute_precision(
            states["true_pos_sum"],
            states["false_pos_sum"],
        )


class PrecisionMetricTest(unittest.TestCase):
    target_clazz: Type[RecMetric] = PrecisionMetric
    task_name: str = "precision"

    def test_precision_unfused(self) -> None:
        rec_metric_value_test_launcher(
            target_clazz=PrecisionMetric,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestPrecisionMetric,
            metric_name=PrecisionMetricTest.task_name,
            task_names=["t1", "t2", "t3"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=WORLD_SIZE,
            entry_point=metric_test_helper,
        )

    def test_precision_fused_tasks(self) -> None:
        rec_metric_value_test_launcher(
            target_clazz=PrecisionMetric,
            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
            test_clazz=TestPrecisionMetric,
            metric_name=PrecisionMetricTest.task_name,
            task_names=["t1", "t2", "t3"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=WORLD_SIZE,
            entry_point=metric_test_helper,
        )

    def test_precision_fused_tasks_and_states(self) -> None:
        rec_metric_value_test_launcher(
            target_clazz=PrecisionMetric,
            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
            test_clazz=TestPrecisionMetric,
            metric_name=PrecisionMetricTest.task_name,
            task_names=["t1", "t2", "t3"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=WORLD_SIZE,
            entry_point=metric_test_helper,
        )


class PrecisionMetricValueTest(unittest.TestCase):
    r"""This set of tests verify the computation logic of precision in several
    corner cases that we know the computation results. The goal is to
    provide some confidence of the correctness of the math formula.
    """

    def setUp(self) -> None:
        self.predictions = {"DefaultTask": None}
        self.weights = {"DefaultTask": None}
        self.labels = {"DefaultTask": None}
        self.batches = {
            "predictions": self.predictions,
            "weights": self.weights,
            "labels": self.labels,
        }
        self.precision = PrecisionMetric(
            world_size=1,
            my_rank=0,
            batch_size=100,
            tasks=[DefaultTaskInfo],
        )

    def test_calc_acc_perfect(self) -> None:
        self.predictions["DefaultTask"] = torch.Tensor(
            [[0.0001 * x for x in range(10000)] * 2]
        )
        self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000])
        self.weights["DefaultTask"] = torch.Tensor(
            [[1] * 5000 + [0] * 10000 + [1] * 5000]
        )

        expected_precision = torch.tensor([1], dtype=torch.double)
        self.precision.update(**self.batches)
        actual_precision = self.precision.compute()[
            "precision-DefaultTask|window_precision"
        ]
        torch.allclose(expected_precision, actual_precision)

    def test_calc_acc_zero(self) -> None:
        self.predictions["DefaultTask"] = torch.Tensor(
            [[0.0001 * x for x in range(10000)] * 2]
        )
        self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000])
        self.weights["DefaultTask"] = torch.Tensor(
            [[0] * 5000 + [1] * 10000 + [0] * 5000]
        )

        expected_precision = torch.tensor([0], dtype=torch.double)
        self.precision.update(**self.batches)
        actual_precision = self.precision.compute()[
            "precision-DefaultTask|window_precision"
        ]
        torch.allclose(expected_precision, actual_precision)

    def test_calc_precision_balanced(self) -> None:
        self.predictions["DefaultTask"] = torch.Tensor(
            [[0.0001 * x for x in range(10000)] * 2]
        )
        self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000])
        self.weights["DefaultTask"] = torch.ones([1, 20000])

        expected_precision = torch.tensor([0.5], dtype=torch.double)
        self.precision.update(**self.batches)
        actual_precision = self.precision.compute()[
            "precision-DefaultTask|window_precision"
        ]
        torch.allclose(expected_precision, actual_precision)


def generate_model_outputs_cases() -> Iterable[Dict[str, Union[float, torch.Tensor]]]:
    return [
        # random_inputs
        {
            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
            "weights": torch.tensor([[0.3, 0.2, 0.5, 0.8, 0.7]]),
            "threshold": 0.6,
            "expected_precision": torch.tensor([0.5]),
        },
        # perfect_condition
        {
            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
            "predictions": torch.tensor([[1, 0, 0, 1, 1]]),
            "weights": torch.tensor([[1] * 5]),
            "threshold": 0.6,
            "expected_precision": torch.tensor([1.0]),
        },
        # inverse_prediction
        {
            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
            "predictions": torch.tensor([[0, 1, 1, 0, 0]]),
            "weights": torch.tensor([[1] * 5]),
            "threshold": 0.1,
            "expected_precision": torch.tensor([0.0]),
        },
    ]


class ThresholdValueTest(unittest.TestCase):
    """This set of tests verify the computation logic of precision with a modified threshold
    in several cases that we know the computation results.
    """

    @no_grad()
    def _test_precision_helper(
        self,
        labels: torch.Tensor,
        predictions: torch.Tensor,
        weights: torch.Tensor,
        expected_precision: torch.Tensor,
        threshold: float,
    ) -> None:
        num_task = labels.shape[0]
        batch_size = labels.shape[0]
        task_list = []
        inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = {
            "predictions": {},
            "labels": {},
            "weights": {},
        }
        for i in range(num_task):
            task_info = RecTaskInfo(
                name=f"Task:{i}",
                label_name="label",
                prediction_name="prediction",
                weight_name="weight",
            )
            task_list.append(task_info)
            # pyre-ignore
            inputs["predictions"][task_info.name] = predictions[i]
            # pyre-ignore
            inputs["labels"][task_info.name] = labels[i]
            # pyre-ignore
            inputs["weights"][task_info.name] = weights[i]

        precision = PrecisionMetric(
            world_size=1,
            my_rank=0,
            batch_size=batch_size,
            tasks=task_list,
            # pyre-ignore
            threshold=threshold,  # threshold is one of the kwargs
        )
        precision.update(**inputs)
        actual_precision = precision.compute()

        for task_id, task in enumerate(task_list):
            cur_actual_precision = actual_precision[
                f"precision-{task.name}|window_precision"
            ]
            cur_expected_precision = expected_precision[task_id].unsqueeze(dim=0)

            torch.testing.assert_close(
                cur_actual_precision,
                cur_expected_precision,
                atol=1e-4,
                rtol=1e-4,
                check_dtype=False,
                msg=f"Actual: {cur_actual_precision}, Expected: {cur_expected_precision}",
            )

    def test_precision(self) -> None:
        test_data = generate_model_outputs_cases()
        for inputs in test_data:
            try:
                self._test_precision_helper(
                    **inputs  # pyre-ignore, surpressing a type hint error
                )
            except AssertionError:
                print("Assertion error caught with data set ", inputs)
                raise


class PrecisionGPUSyncTest(unittest.TestCase):
    clazz: Type[RecMetric] = PrecisionMetric
    task_name: str = "precision"

    def test_sync_precision(self) -> None:
        rec_metric_gpu_sync_test_launcher(
            target_clazz=PrecisionMetric,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestPrecisionMetric,
            metric_name=PrecisionGPUSyncTest.task_name,
            task_names=["t1"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            entry_point=sync_test_helper,
        )
